{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "dimension = 400\n",
    "vocab = \"EOS abcdefghijklmnopqrstuvwxyz'\"\n",
    "char2idx = {char: idx for idx, char in enumerate(vocab)}\n",
    "idx2char = {idx: char for idx, char in enumerate(vocab)}\n",
    "\n",
    "def text2idx(text):\n",
    "    text = re.sub(r'[^a-z ]', '', text.lower()).strip()\n",
    "    converted = [char2idx[char] for char in text]\n",
    "    return text, converted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "GO = 1\n",
    "PAD = 0\n",
    "EOS = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "train_X, train_Y = [], []\n",
    "text_files = [f for f in os.listdir('spectrogram-train') if f.endswith('.npy')]\n",
    "for fpath in text_files:\n",
    "    try:\n",
    "        splitted = fpath.split('-')\n",
    "        if len(splitted) == 2:\n",
    "            splitted[1] = splitted[1].split('.')[1]\n",
    "            fpath = splitted[0] + '.' + splitted[1]\n",
    "        with open('data/' + fpath.replace('npy', 'txt')) as fopen:\n",
    "            text, converted = text2idx(fopen.read())\n",
    "        w = np.load('spectrogram-train/' + fpath)\n",
    "        if w.shape[1] != dimension:\n",
    "            continue\n",
    "        train_X.append(w)\n",
    "        train_Y.append(converted)\n",
    "    except:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_X, test_Y = [], []\n",
    "text_files = [f for f in os.listdir('spectrogram-test') if f.endswith('.npy')]\n",
    "for fpath in text_files:\n",
    "    with open('data/' + fpath.replace('npy', 'txt')) as fopen:\n",
    "        text, converted = text2idx(fopen.read())\n",
    "    w = np.load('spectrogram-test/' + fpath)\n",
    "    if w.shape[1] != dimension:\n",
    "        continue\n",
    "    test_X.append(w)\n",
    "    test_Y.append(converted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_second_dim(x, desired_size):\n",
    "    padding = tf.tile([[0]], tf.stack([tf.shape(x)[0], desired_size - tf.shape(x)[1]], 0))\n",
    "    return tf.concat([x, padding], 1)\n",
    "\n",
    "_BATCH_NORM_EPSILON = 1e-5\n",
    "_BATCH_NORM_DECAY = 0.997\n",
    "_CONV_FILTERS = 32\n",
    "\n",
    "def batch_norm(inputs, training):\n",
    "    return tf.layers.batch_normalization(\n",
    "      inputs=inputs, momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON,\n",
    "      fused=True, training=training)\n",
    "\n",
    "def _conv_bn_layer(inputs, padding, filters, kernel_size, strides, layer_id,\n",
    "                   training):\n",
    "    inputs = tf.pad(\n",
    "      inputs,\n",
    "      [[0, 0], [padding[0], padding[0]], [padding[1], padding[1]], [0, 0]])\n",
    "    inputs = tf.layers.conv2d(\n",
    "      inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,\n",
    "      padding=\"valid\", use_bias=False, activation=tf.nn.relu6,\n",
    "      name=\"cnn_{}\".format(layer_id))\n",
    "    return inputs\n",
    "    #return batch_norm(inputs, training)\n",
    "\n",
    "def _rnn_layer(inputs, rnn_cell, rnn_hidden_size, layer_id, is_batch_norm,\n",
    "               is_bidirectional, training):\n",
    "    if is_batch_norm:\n",
    "        inputs = batch_norm(inputs, training)\n",
    "    \n",
    "    fw_cell = rnn_cell(num_units=rnn_hidden_size,\n",
    "                     name=\"rnn_fw_{}\".format(layer_id))\n",
    "    bw_cell = rnn_cell(num_units=rnn_hidden_size,\n",
    "                     name=\"rnn_bw_{}\".format(layer_id))\n",
    "\n",
    "    if is_bidirectional:\n",
    "        outputs, _ = tf.nn.bidirectional_dynamic_rnn(\n",
    "        cell_fw=fw_cell, cell_bw=bw_cell, inputs=inputs, dtype=tf.float32,\n",
    "        swap_memory=True)\n",
    "        rnn_outputs = tf.concat(outputs, -1)\n",
    "    else:\n",
    "        rnn_outputs = tf.nn.dynamic_rnn(\n",
    "        fw_cell, inputs, dtype=tf.float32, swap_memory=True)\n",
    "\n",
    "    return rnn_outputs\n",
    "\n",
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        size_layers,\n",
    "        learning_rate,\n",
    "        num_features,\n",
    "        dropout = 1.0,\n",
    "    ):\n",
    "        self.X = tf.placeholder(tf.float32, [None, None, num_features])\n",
    "        self.label = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y_seq_len = tf.placeholder(tf.int32, [None])\n",
    "        self.training = tf.placeholder(tf.bool, None)\n",
    "        self.Y = tf.sparse_placeholder(tf.int32)\n",
    "        x = tf.expand_dims(self.X, -1)\n",
    "\n",
    "        inputs = _conv_bn_layer(\n",
    "            x, padding=(20, 5), filters=_CONV_FILTERS, kernel_size=(41, 11),\n",
    "            strides=(2, 2), layer_id=1, training=self.training)\n",
    "        \n",
    "        inputs = _conv_bn_layer(\n",
    "            inputs, padding=(10, 5), filters=_CONV_FILTERS, kernel_size=(21, 11),\n",
    "            strides=(2, 1), layer_id=2, training=self.training)\n",
    "        \n",
    "        batch_size = tf.shape(inputs)[0]\n",
    "        feat_size = inputs.get_shape().as_list()[2]\n",
    "        inputs = tf.reshape(\n",
    "            inputs,\n",
    "            [batch_size, -1, feat_size * _CONV_FILTERS // 4])\n",
    "        print(inputs)\n",
    "        \n",
    "        seq_lens = tf.count_nonzero(\n",
    "            tf.reduce_sum(inputs, -1), 1, dtype = tf.int32\n",
    "        ) + 30\n",
    "        filled = tf.fill(tf.shape(seq_lens), tf.shape(inputs)[1])\n",
    "        seq_lens = tf.where(seq_lens > tf.shape(inputs)[1], filled, seq_lens)\n",
    "        \n",
    "        rnn_cell = tf.nn.rnn_cell.GRUCell\n",
    "        for layer_counter in range(5):\n",
    "            is_batch_norm = (layer_counter != 0)\n",
    "            inputs = _rnn_layer(\n",
    "              inputs, rnn_cell, size_layers, layer_counter + 1,\n",
    "              False, True, self.training)\n",
    "        \n",
    "\n",
    "        logits = tf.layers.dense(inputs, len(vocab))\n",
    "        self.logits = logits\n",
    "        time_major = tf.transpose(logits, [1, 0, 2])\n",
    "        decoded, log_prob = tf.nn.ctc_beam_search_decoder(time_major, seq_lens)\n",
    "        decoded = tf.to_int32(decoded[0])\n",
    "        self.preds = tf.sparse_tensor_to_dense(decoded)\n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.ctc_loss(\n",
    "                self.Y,\n",
    "                time_major,\n",
    "                seq_lens\n",
    "            )\n",
    "        )\n",
    "        self.optimizer = tf.train.AdamOptimizer(\n",
    "            learning_rate = learning_rate\n",
    "        ).minimize(self.cost)\n",
    "        \n",
    "        preds = self.preds[:, :tf.reduce_max(self.Y_seq_len)]\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        preds = pad_second_dim(preds, tf.reduce_max(self.Y_seq_len))\n",
    "        y_t = tf.cast(preds, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.label, masks)\n",
    "        self.mask_label = mask_label\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py:1735: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"Reshape:0\", shape=(?, ?, 1600), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "\n",
    "size_layers = 512\n",
    "learning_rate = 1e-4\n",
    "num_layers = 2\n",
    "batch_size = 64\n",
    "epoch = 20\n",
    "\n",
    "model = Model(size_layers, learning_rate, dimension)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "    train_X, dtype = 'float32', padding = 'post'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_X = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "    test_X, dtype = 'float32', padding = 'post'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(len(sentence))\n",
    "    return padded_seqs, seq_lens\n",
    "\n",
    "def sparse_tuple_from(sequences, dtype=np.int32):\n",
    "    indices = []\n",
    "    values = []\n",
    "\n",
    "    for n, seq in enumerate(sequences):\n",
    "        indices.extend(zip([n] * len(seq), range(len(seq))))\n",
    "        values.extend(seq)\n",
    "\n",
    "    indices = np.asarray(indices, dtype=np.int64)\n",
    "    values = np.asarray(values, dtype=dtype)\n",
    "    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)\n",
    "\n",
    "    return indices, values, shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:06<00:00,  1.22it/s, accuracy=0.755, cost=12.9] \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.06it/s, accuracy=0.761, cost=13]  \n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, training avg cost 22.632313, training avg accuracy 0.575606\n",
      "epoch 1, testing avg cost 12.900587, testing avg accuracy 0.768059\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:23<00:00,  1.24it/s, accuracy=0.777, cost=11.2]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.11it/s, accuracy=0.775, cost=11.3]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, training avg cost 11.911486, training avg accuracy 0.776706\n",
      "epoch 2, testing avg cost 11.289837, testing avg accuracy 0.782253\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:27<00:00,  1.23it/s, accuracy=0.763, cost=10.8]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.06it/s, accuracy=0.784, cost=11.1]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, training avg cost 11.084230, training avg accuracy 0.783230\n",
      "epoch 3, testing avg cost 11.031971, testing avg accuracy 0.785639\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:27<00:00,  1.25it/s, accuracy=0.777, cost=10.5]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.07it/s, accuracy=0.771, cost=10.9]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, training avg cost 10.702530, training avg accuracy 0.785045\n",
      "epoch 4, testing avg cost 10.779385, testing avg accuracy 0.785688\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:27<00:00,  1.20it/s, accuracy=0.77, cost=10.4] \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.08it/s, accuracy=0.774, cost=10.5]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, training avg cost 10.371732, training avg accuracy 0.786856\n",
      "epoch 5, testing avg cost 10.448628, testing avg accuracy 0.785171\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:24<00:00,  1.27it/s, accuracy=0.784, cost=10.1]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.12it/s, accuracy=0.784, cost=10.5]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, training avg cost 10.058434, training avg accuracy 0.790098\n",
      "epoch 6, testing avg cost 10.403775, testing avg accuracy 0.792030\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:25<00:00,  1.24it/s, accuracy=0.799, cost=9.36]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.11it/s, accuracy=0.783, cost=10.5]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, training avg cost 9.571798, training avg accuracy 0.797075\n",
      "epoch 7, testing avg cost 10.324467, testing avg accuracy 0.797482\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:25<00:00,  1.26it/s, accuracy=0.849, cost=7.87]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.06it/s, accuracy=0.792, cost=10.4]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, training avg cost 8.889078, training avg accuracy 0.809279\n",
      "epoch 8, testing avg cost 10.178507, testing avg accuracy 0.798248\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:23<00:00,  1.24it/s, accuracy=0.878, cost=6.58]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.11it/s, accuracy=0.791, cost=10.6]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, training avg cost 7.975466, training avg accuracy 0.827766\n",
      "epoch 9, testing avg cost 10.161058, testing avg accuracy 0.803391\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:24<00:00,  1.25it/s, accuracy=0.892, cost=5.13]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.11it/s, accuracy=0.792, cost=10.9]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, training avg cost 6.821675, training avg accuracy 0.850467\n",
      "epoch 10, testing avg cost 10.476594, testing avg accuracy 0.805270\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:25<00:00,  1.24it/s, accuracy=0.921, cost=3.82]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.08it/s, accuracy=0.801, cost=11.6]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11, training avg cost 5.510968, training avg accuracy 0.877792\n",
      "epoch 11, testing avg cost 11.025077, testing avg accuracy 0.808751\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:26<00:00,  1.28it/s, accuracy=0.935, cost=2.55]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.10it/s, accuracy=0.806, cost=12.9]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12, training avg cost 4.202572, training avg accuracy 0.906124\n",
      "epoch 12, testing avg cost 12.070560, testing avg accuracy 0.809879\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:27<00:00,  1.23it/s, accuracy=0.964, cost=1.55]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.09it/s, accuracy=0.803, cost=14.4]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13, training avg cost 3.001554, training avg accuracy 0.931464\n",
      "epoch 13, testing avg cost 13.581208, testing avg accuracy 0.809106\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:27<00:00,  1.23it/s, accuracy=0.95, cost=1.2]  \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.08it/s, accuracy=0.806, cost=15.2]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14, training avg cost 2.089147, training avg accuracy 0.950307\n",
      "epoch 14, testing avg cost 14.583699, testing avg accuracy 0.811015\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:28<00:00,  1.22it/s, accuracy=0.95, cost=1.05]  \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.10it/s, accuracy=0.813, cost=16.4]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15, training avg cost 1.424262, training avg accuracy 0.962460\n",
      "epoch 15, testing avg cost 15.670724, testing avg accuracy 0.812718\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:30<00:00,  1.20it/s, accuracy=0.95, cost=0.535] \n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.09it/s, accuracy=0.81, cost=18.3] \n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16, training avg cost 0.927936, training avg accuracy 0.971807\n",
      "epoch 16, testing avg cost 16.953733, testing avg accuracy 0.812459\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:32<00:00,  1.20it/s, accuracy=0.942, cost=0.347]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.08it/s, accuracy=0.8, cost=20.1]  \n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17, training avg cost 0.618884, training avg accuracy 0.977377\n",
      "epoch 17, testing avg cost 18.126024, testing avg accuracy 0.811326\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:34<00:00,  1.20it/s, accuracy=0.971, cost=0.107]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.07it/s, accuracy=0.809, cost=19]  \n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18, training avg cost 0.462682, training avg accuracy 0.979819\n",
      "epoch 18, testing avg cost 18.240946, testing avg accuracy 0.813657\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:35<00:00,  1.19it/s, accuracy=0.971, cost=0.135]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.07it/s, accuracy=0.803, cost=20.3]\n",
      "minibatch loop:   0%|          | 0/206 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19, training avg cost 0.302372, training avg accuracy 0.982580\n",
      "epoch 19, testing avg cost 19.257372, testing avg accuracy 0.814184\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 206/206 [03:36<00:00,  1.19it/s, accuracy=0.971, cost=0.0488]\n",
      "testing minibatch loop: 100%|██████████| 9/9 [00:08<00:00,  1.07it/s, accuracy=0.818, cost=20.1]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, training avg cost 0.188500, training avg accuracy 0.984385\n",
      "epoch 20, testing avg cost 19.263779, testing avg accuracy 0.814071\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'minibatch loop')\n",
    "    train_cost, train_accuracy, test_cost, test_accuracy = [], [], [], []\n",
    "    for i in pbar:\n",
    "        batch_x = train_X[i : min(i + batch_size, len(train_X))]\n",
    "        y = train_Y[i : min(i + batch_size, len(train_X))]\n",
    "        batch_y = sparse_tuple_from(y)\n",
    "        batch_label, batch_len = pad_sentence_batch(y, 0)\n",
    "        _, cost, accuracy = sess.run(\n",
    "            [model.optimizer, model.cost, model.accuracy],\n",
    "            feed_dict = {model.X: batch_x, model.Y: batch_y, \n",
    "                         model.label: batch_label, model.Y_seq_len: batch_len},\n",
    "        )\n",
    "        train_cost.append(cost)\n",
    "        train_accuracy.append(accuracy)\n",
    "        pbar.set_postfix(cost = cost, accuracy = accuracy)\n",
    "    \n",
    "    pbar = tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'testing minibatch loop')\n",
    "    for i in pbar:\n",
    "        batch_x = test_X[i : min(i + batch_size, len(test_X))]\n",
    "        y = test_Y[i : min(i + batch_size, len(test_X))]\n",
    "        batch_y = sparse_tuple_from(y)\n",
    "        batch_label, batch_len = pad_sentence_batch(y, 0)\n",
    "        cost, accuracy = sess.run(\n",
    "            [model.cost, model.accuracy],\n",
    "            feed_dict = {model.X: batch_x, model.Y: batch_y, \n",
    "                         model.label: batch_label, model.Y_seq_len: batch_len},\n",
    "        )\n",
    "        \n",
    "        test_cost.append(cost)\n",
    "        test_accuracy.append(accuracy)\n",
    "        \n",
    "        pbar.set_postfix(cost = cost, accuracy = accuracy)\n",
    "    print('epoch %d, training avg cost %f, training avg accuracy %f'%(e + 1, np.mean(train_cost), \n",
    "                                                                      np.mean(train_accuracy)))\n",
    "    \n",
    "    print('epoch %d, testing avg cost %f, testing avg accuracy %f'%(e + 1, np.mean(test_cost), \n",
    "                                                                    np.mean(test_accuracy)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "real: say the word which\n",
      "predicted: say the word hith\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "random_index = random.randint(0, len(test_X) - 1)\n",
    "batch_x = test_X[random_index : random_index + 1]\n",
    "print(\n",
    "    'real:',\n",
    "    ''.join(\n",
    "        [idx2char[no] for no in test_Y[random_index : random_index + 1][0]]\n",
    "    ),\n",
    ")\n",
    "batch_y = sparse_tuple_from(test_Y[random_index : random_index + 1])\n",
    "pred = sess.run(model.preds, feed_dict = {model.X: batch_x})[0]\n",
    "print('predicted:', ''.join([idx2char[no] for no in pred]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
